# datasets/base.py
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Dict, Any, Protocol, Tuple

import numpy as np

@dataclass
class DatasetSpec:
    """
    Bundle of arrays and metadata describing one benchmark dataset.

    This is a plain mutable record: nothing is validated or converted when an
    instance is built, so the conventions below are expectations for whoever
    produces the spec, not checks performed here.

    Attributes
    ----------
    name : str
        Canonical dataset identifier, e.g. ``"pbmc3k"``.
    X : np.ndarray
        Cells-by-features matrix in ambient/preprocessed space; float32 by
        convention.
    labels : Optional[np.ndarray]
        Per-cell cell-type or condition labels (one entry per cell), or None.
    batch : Optional[np.ndarray]
        Per-cell batch/experiment IDs (one entry per cell) used by mixing
        metrics, or None.
    meta : Optional[Dict[str, Any]]
        Free-form extras such as feature names, obs_df, raw counts or file
        paths, or None.
    """

    # Required: identifier and the feature matrix.
    name: str
    X: np.ndarray
    # Optional per-cell annotations and free-form metadata (absent by default).
    labels: Optional[np.ndarray] = None
    batch: Optional[np.ndarray] = None
    meta: Optional[Dict[str, Any]] = None


class Loader(Protocol):
    """
    Structural type for any callable that produces a ``DatasetSpec``.

    Implementations accept an optional ``cache_dir`` plus loader-specific
    keyword options.
    """

    def __call__(self, cache_dir: Optional[str] = None, **kwargs) -> DatasetSpec:
        """Return the loaded dataset; how ``cache_dir`` is used is implementation-defined."""


class Preprocessor(Protocol):
    """
    Structural type for a callable that preprocesses a feature matrix.

    Implementations take ``X`` plus keyword-only options and return a
    ``(transformed_matrix, info_dict)`` pair.
    """

    def __call__(
        self,
        X: np.ndarray,
        *,
        normalize: bool = True,
        log1p: bool = True,
        hvg_n: Optional[int] = 2000,
        pca_n: Optional[int] = 50,
        random_state: int = 0,
    ) -> Tuple[np.ndarray, Dict[str, Any]]:
        """
        Transform ``X``; the exact meaning of each option is implementation-defined.

        The option names suggest normalization, a log1p transform, keeping
        ``hvg_n`` highly-variable features and a ``pca_n``-component PCA, with
        ``None`` presumably disabling the respective step. Confirm this for
        each implementation.
        """


import sys
# NOTE(review): prepends a hard-coded Google Drive (Colab-style) directory to the
# module search path as an import-time side effect. This is presumably so that
# project modules stored there can be imported from a notebook. The path is
# machine-specific; confirm it is still needed, or move it into the notebook.
sys.path.insert(0,'/content/drive/MyDrive/embeddings')

# Download helpers, synthetic generators and the `load_dataset` registry follow.


from functools import partial
from urllib.parse import urlparse, unquote
import re, h5py, os
import anndata as ad
from sklearn.datasets import make_moons, load_digits, make_checkerboard, make_swiss_roll, fetch_openml, make_s_curve, make_circles
from zipfile import ZipFile
from pathlib import Path
from tqdm import tqdm
import numpy as np
import struct
from array import array
from os.path  import join


def chessboard_cloud(n=8, per_side=20, pad=0.0, mode="grid", rng=None):
    """
    Create a point cloud that fills each tile of an n x n chessboard.

    Parameters
    ----------
    n : int
        Board size (n x n). Must be >= 1.
    per_side : int
        Points along one side of each tile (grid mode). Total per tile = per_side**2.
        In 'random' mode this controls density via per_side**2 points per tile.
        Must be >= 1; with per_side == 1 (grid mode) each tile gets one point
        at its centre.
    pad : float
        Padding margin inside each tile, in tile units, 0 <= pad < 0.5.
        E.g., pad=0.02 leaves a thin gap between neighboring tiles.
    mode : {'grid','random'}
        'grid' for a regular lattice inside each tile, 'random' for uniform random.
    rng : None | int | np.random.Generator
        Random seed or Generator (used only in 'random' mode).

    Returns
    -------
    X : (N, 2) float32 ndarray
        2D coordinates for scatter (tiles span [i-0.5, i+0.5] x [j-0.5, j+0.5]),
        with N = n**2 * per_side**2.
    y : (N,) int8 ndarray
        Binary labels per point: 0 = black, 1 = white. (0,0) tile is black.

    Raises
    ------
    ValueError
        If n or per_side is < 1, pad is outside [0, 0.5), or mode is unknown.
    """
    if n < 1:
        raise ValueError("n must be a positive integer.")
    if per_side < 1:
        raise ValueError("per_side must be a positive integer.")
    if not (0.0 <= pad < 0.5):
        raise ValueError("pad must be in [0, 0.5).")

    # Tile centers on integer grid
    cx, cy = np.meshgrid(np.arange(n), np.arange(n))
    centers = np.column_stack([cx.ravel(), cy.ravel()])
    labels = ((cx + cy) % 2).ravel()  # 0 black, 1 white (A1 at (0,0) is black)
    m = per_side * per_side  # points per tile

    if mode == "grid":
        # np.linspace(a, b, 1) returns just `a` (the tile corner, which lies on
        # the tile boundary when pad == 0), so a single point per side is
        # placed at the tile centre instead.
        if per_side > 1:
            u = np.linspace(-0.5 + pad, 0.5 - pad, per_side)
        else:
            u = np.zeros(1)
        U, V = np.meshgrid(u, u)  # local coords inside a tile
        local = np.column_stack([U.ravel(), V.ravel()])  # (per_side**2, 2)
        X = np.repeat(centers, m, axis=0) + np.tile(local, (n * n, 1))
        y = np.repeat(labels, m)

    elif mode == "random":
        if rng is None or isinstance(rng, (int, np.integer)):
            rng = np.random.default_rng(rng)
        low, high = -0.5 + pad, 0.5 - pad
        jitter = rng.uniform(low, high, size=(n * n * m, 2))
        X = np.repeat(centers, m, axis=0) + jitter
        y = np.repeat(labels, m)

    else:
        raise ValueError("mode must be 'grid' or 'random'")

    return X.astype(np.float32), y.astype(np.int8)



def _filename_from_cd(content_disposition: str | None) -> str | None:
    """Extract filename from a Content-Disposition header (RFC 6266 / 5987)."""
    if not content_disposition:
        return None
    m = re.search(r'filename\*=\s*UTF-8\'\'([^\s;]+)', content_disposition, flags=re.IGNORECASE)
    if m:
        return unquote(m.group(1))
    m = re.search(r'filename="?([^";]+)"?', content_disposition, flags=re.IGNORECASE)
    if m:
        return m.group(1)
    return None



def _download_file(url: str,
                   dest_path: str | Path | None = None,
                   *,
                   overwrite: bool = False,
                   chunk_size: int = 1 << 20,   # 1 MiB
                   timeout: int = 30,
                   show_progress: bool = True) -> Path:
    """
    Download a URL to disk and return the saved Path.

    Parameters
    ----------
    url : str
        Address to fetch. The final (post-redirect) URL is used to derive a
        filename when the server sends no usable Content-Disposition.
    dest_path : str | Path | None
        Where to save.
        - None or "" means the current working directory.
        - An existing directory, or a string ending in a path separator
          (e.g. "./data/"), is treated as a directory. It is created if
          missing, and the filename is auto-detected.
        - Anything else is used as the exact file path; its parent directory
          is created if missing.
    overwrite : bool
        Replace an existing destination file instead of raising.
    chunk_size : int
        Bytes requested per read while streaming.
    timeout : int
        Network timeout (seconds) passed to the HTTP client.
    show_progress : bool
        Display a progress bar (tqdm when importable, else a plain-text fallback).

    Returns
    -------
    Path
        Location of the downloaded file.

    Raises
    ------
    FileExistsError
        If the destination exists and ``overwrite`` is False. This is raised
        after the connection is opened but before anything is written.
    requests.HTTPError / urllib.error.HTTPError
        On an HTTP error status, depending on which client is available.

    Notes
    -----
    Bytes are streamed into a sibling ``<name>.part`` file, which is renamed
    onto the destination only after the transfer completes. An interrupted
    download is removed rather than left behind as a truncated file.
    Auto-detected filenames are reduced to their last path component, so a
    hostile header or URL cannot write outside the destination directory.
    """
    import os

    # Optional tqdm import (best effort: any import-time failure disables it).
    tqdm_cls = None
    if show_progress:
        try:
            from tqdm import tqdm as tqdm_cls  # type: ignore
        except Exception:
            tqdm_cls = None

    # Prefer 'requests' if available; otherwise fall back to stdlib.
    try:
        import requests  # type: ignore
    except ImportError:
        requests = None

    def _safe_name(raw: str | None) -> str | None:
        """Reduce ``raw`` to a bare filename, or None if nothing usable remains."""
        if not raw:
            return None
        name = raw.strip().replace("\\", "/").rsplit("/", 1)[-1]
        return None if name in ("", ".", "..") else name

    def _pick_name(content_disposition: str | None, final_url: str) -> str:
        """Prefer the server-provided filename, then the URL's last path segment."""
        url_name = Path(unquote(urlparse(final_url).path)).name
        return (_safe_name(_filename_from_cd(content_disposition))
                or _safe_name(url_name)
                or "download")

    def _resolve_dest(suggested_name: str) -> Path:
        """Combine ``dest_path`` with the suggested filename into the final path."""
        if not dest_path:
            base, is_dir = Path.cwd(), True
        else:
            raw = os.fspath(dest_path)
            base = Path(raw).expanduser()
            # Path() drops a trailing separator, so inspect the raw string: a
            # not-yet-existing "./data/" means a directory, not a file "data".
            seps = tuple(s for s in (os.sep, os.altsep) if s)
            is_dir = base.is_dir() or raw.endswith(seps)
        if is_dir:
            base.mkdir(parents=True, exist_ok=True)
            return base / suggested_name
        base.parent.mkdir(parents=True, exist_ok=True)
        return base

    def _checked_dest(name: str) -> Path:
        """Resolve the destination and refuse to clobber it unless overwrite=True."""
        dest = _resolve_dest(name)
        if dest.exists() and not overwrite:
            raise FileExistsError(f"{dest} already exists (set overwrite=True to replace).")
        return dest

    def _parse_length(value: str | None) -> int | None:
        """Parse a Content-Length header; None when absent or non-numeric."""
        return int(value) if value and value.isdigit() else None

    class _BasicProgress:
        """Plain-text progress fallback used when tqdm is unavailable."""

        def __init__(self, total):
            self.total = total
            self.n = 0
            self._last_pct = -1

        def update(self, inc):
            self.n += inc
            if self.total:
                pct = int(self.n * 100 / self.total)
                if pct != self._last_pct:
                    self._last_pct = pct
                    print(f"\rDownloading... {pct:3d}%", end="", flush=True)
            else:
                # unknown length
                dots = "." * ((self.n // (1 << 20)) % 10)
                print(f"\rDownloading{dots}", end="", flush=True)

        def close(self):
            if self.total:
                print("\rDownloading... 100%")
            else:
                print("\rDone.           ")

    def _make_bar(total: int | None, desc: str):
        """Return an object with update()/close(), or None when progress is off."""
        if not show_progress:
            return None
        if tqdm_cls is not None:
            return tqdm_cls(total=total if total and total > 0 else None,
                            unit="B", unit_scale=True, unit_divisor=1024,
                            desc=desc, leave=False)
        return _BasicProgress(total)

    def _stream_to(dest: Path, chunks, total: int | None, desc: str) -> Path:
        """Write ``chunks`` into ``dest`` via a temporary ``.part`` file."""
        part = dest.with_name(dest.name + ".part")
        bar = _make_bar(total, desc)
        try:
            with open(part, "wb") as f:
                for chunk in chunks:
                    if not chunk:
                        continue
                    f.write(chunk)
                    # Compare with None rather than relying on truthiness:
                    # tqdm.__bool__ raises TypeError when the total is unknown.
                    if bar is not None:
                        bar.update(len(chunk))
            part.replace(dest)  # only a complete download gets the final name
        except BaseException:
            try:
                part.unlink()
            except OSError:
                pass
            raise
        finally:
            if bar is not None:
                bar.close()
        return dest

    if requests is not None:
        with requests.get(url, stream=True, timeout=timeout) as resp:
            resp.raise_for_status()
            name = _pick_name(resp.headers.get("content-disposition"), resp.url)
            dest = _checked_dest(name)
            total = _parse_length(resp.headers.get("content-length"))
            return _stream_to(dest, resp.iter_content(chunk_size=chunk_size), total, name)

    from urllib.request import urlopen, Request
    req = Request(url, headers={"User-Agent": "python-downloader/1.0"})
    with urlopen(req, timeout=timeout) as resp:
        name = _pick_name(resp.headers.get("Content-Disposition"), resp.geturl())
        dest = _checked_dest(name)
        total = _parse_length(resp.headers.get("Content-Length"))
        # read() returns b"" at EOF, which ends the iterator.
        chunks = iter(lambda: resp.read(chunk_size), b"")
        return _stream_to(dest, chunks, total, name)




def _process_openml(dataset: str, version: int):
    """
    Fetch an OpenML dataset by name and version via scikit-learn's ``fetch_openml``.

    Returns ``(features, target)``: the features are cast to float and the
    target is returned exactly as ``fetch_openml`` provides it.
    """
    bunch = fetch_openml(name=dataset, version=version)
    features = bunch['data']
    target = bunch['target']
    return features.astype(float), target



def _process_digits():
    """Return scikit-learn's handwritten-digits dataset as a ``(data, target)`` tuple."""
    bunch = load_digits()
    features, targets = bunch['data'], bunch['target']
    return features, targets


def _process_anndata(filename: str, save_path: str = './data/'):
    """Read ``<save_path>/<filename>`` with ``anndata.read_h5ad`` and return the result."""
    h5ad_file = join(save_path, filename)
    return ad.read_h5ad(h5ad_file)



def load_dataset(
    dataset: str,
    save_path: str = "./data/",
):
    """
    Load a benchmark dataset by name.

    Remote single-cell datasets are ``.h5ad`` files. Each is downloaded once
    into ``save_path`` under a fixed local filename and then read with
    ``_process_anndata`` (``anndata.read_h5ad``). All other names are produced
    by a generator or fetch helper (scikit-learn / ``chessboard_cloud``),
    called without arguments; each returns a two-element tuple such as
    ``(X, y)``.

    Parameters
    ----------
    dataset : str
        Registry key, e.g. ``"pbmc3k"``, ``"moons"`` or ``"mnist"``.
    save_path : str
        Directory in which remote ``.h5ad`` files are cached. If a file is
        already present there, no network request is made.

    Returns
    -------
    object
        The loaded dataset. Its type depends on ``dataset`` (see above).

    Raises
    ------
    ValueError
        If ``dataset`` is not a registered name.
    """
    # Remote AnnData files: name -> (download URL, filename stored under save_path).
    # The local filename is fixed here instead of taken from the server, so it
    # always matches what is read back. Several hosted files carry different
    # names: e.g. the monocyte URL ends in "preporocessed", the kang URL in
    # "perturtbations", and the lamanno URL in "processed".
    remote = {
        'chronocellsim': ('https://github.com/pachterlab/FGP_2024/raw/refs/heads/main/data/sim_disconnect.h5ad', 'sim_disconnect.h5ad'),
        'vu': ('https://www.dropbox.com/scl/fi/rd24bhlp0urorxs4499y4/vu_2022_ay_wh.h5ad?rlkey=jafslgnqsjcz2ascvaxu7shph&st=ju5bp2t8&dl=1', 'vu_2022_ay_wh.h5ad'),
        'pbmc3k': ('https://www.dropbox.com/scl/fi/5pbb6g8ya4v1try3501yl/pbmc3k_preprocessed.h5ad?rlkey=oqlas7k9obq6szgce847psuzn&st=33x9atw2&dl=1', 'pbmc3k_preprocessed.h5ad'),
        'monocytedrug': ('https://www.dropbox.com/scl/fi/xob63ct6idfxofb4z2jg6/monocyte_drug_preporocessed.h5ad?rlkey=srlkr4i1ipmzpt4r7whx9kf7y&st=w84tf0u0&dl=1', 'monocyte_drug_preprocessed.h5ad'),
        'lamanno': ('https://www.dropbox.com/scl/fi/h012yvnlnx652hhy4d4cg/lamanno_processed.h5ad?rlkey=zhwi22t8gqymye4yqn4w3ebg8&st=zf1yevsm&dl=1', 'lamanno_preprocessed.h5ad'),
        'kang': ('https://www.dropbox.com/scl/fi/zxta2nf00p8a9do907rrv/kang_et_al_perturtbations_preprocessed.h5ad?rlkey=bk4pbuily0349borou6rnvjds&st=zycw73z2&dl=1', 'kang_et_al_perturbations_preprocessed.h5ad'),
        'forebrain': ('https://www.dropbox.com/scl/fi/ph8u2rowr63cqsps8noxl/hgForebrainGlut_preprocessed.h5ad?rlkey=hdvy4mq7mpuvey360j8lxltvr&st=nd3b8f98&dl=1', 'hgForebrainGlut_preprocessed.h5ad'),
        'decipher': ('https://www.dropbox.com/scl/fi/wnbkj7ero8mlk4bfo142m/data_decipher_tutorial_preprocessed.h5ad?rlkey=uaevk4idf516udsl4r44984m0&st=jnn9hi88&dl=1', 'data_decipher_tutorial_preprocessed.h5ad'),
        'celegans': ('https://www.dropbox.com/scl/fi/7gd1uo6n4ahnpv7ltit03/celegans_waterston_preprocessed.h5ad?rlkey=u493o77af0p9ugynq16euyn4f&st=6jkx9mki&dl=1', 'celegans_waterston_preprocessed.h5ad'),
    }

    # Generated / scikit-learn datasets: name -> zero-argument callable.
    generated = {
        'moons': partial(make_moons, n_samples=3000, shuffle=True, noise=0.1, random_state=42),
        'digits': _process_digits,
        'checkerboard': partial(chessboard_cloud, rng=42),
        'swiss_roll': partial(make_swiss_roll, n_samples=3000, random_state=42),
        'scurve': partial(make_s_curve, n_samples=3000, random_state=42),
        'circles': partial(make_circles, n_samples=3000, noise=0.1, random_state=42, factor=0.8),
        'usps': partial(_process_openml, dataset='USPS', version=2),
        'mnist': partial(_process_openml, dataset='mnist_784', version=1),
        'semeion': partial(_process_openml, dataset='semeion', version=2),
        'isolet': partial(_process_openml, dataset='isolet', version=1),
        'phoneme': partial(_process_openml, dataset='phoneme', version=1),
        'fashion_mnist': partial(_process_openml, dataset='Fashion-MNIST', version=1),
        'cc_fraud': partial(_process_openml, dataset='CreditCardFraudDetection', version=1),
    }

    if dataset in remote:
        url, filename = remote[dataset]
        target = Path(save_path) / filename
        # Skip the network entirely when the file is already cached.
        if not target.exists():
            try:
                _download_file(url, target)
            except FileExistsError:
                # Created concurrently between the check and the download.
                pass
        return _process_anndata(filename, save_path)

    if dataset in generated:
        return generated[dataset]()

    available = ", ".join(sorted([*remote, *generated]))
    raise ValueError(f"{dataset} not recognized as a dataset. Available: {available}.")
